# Qwen 3.5 MoE — MLX Backend Support
Adds `--backend mlx` to the existing Qwen 3.5 MoE export script,
enabling export and inference on Apple Silicon via the MLX delegate.
## What changed
**Unified export** (`examples/models/qwen3_5_moe/export.py`)
- Added `--backend mlx` alongside the existing CUDA path; the CUDA path
  is unchanged.
- Added `--model-id` for automatic HuggingFace download.
- Added `--tiny-test` for CI validation with random weights (~30s, no
download).
**MLX source transformations**
(`examples/models/qwen3_5_moe/mlx_source_transformations.py`)
- Replaces Triton-dependent modules with MLX equivalents:
  - `FusedMoEExperts` → `SwitchMLP`
  - `GatedDeltaNet` → `mlx::gated_delta_rule` custom op
  - `FullAttention` → `mlx::rope`
  - `KVCache` → MLX KVCache
  - `GemmaRMSNorm` → `F.rms_norm`
  - `SparseMoE` → same module, with unnecessary dtype casts removed
**SwitchLinear / SwitchMLP** (`backends/mlx/llm/switch.py`)
- Per-expert linear using `mlx::gather_mm` / `mlx::gather_qmm` custom
ops.
- `SwitchMLP`: reusable gated MoE MLP with configurable activation and
optional gate+up fusion.
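For reference, the math a gated MoE MLP of this shape computes, written as a hedged numpy sketch (top-1 routing for clarity; the real `SwitchMLP` routes through the fused `mlx::gather_mm` / `mlx::gather_qmm` ops instead of an einsum, and its activation is configurable):

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def switch_mlp(x, expert_idx, w_gate, w_up, w_down):
    """Gated MoE MLP: each token goes through its routed expert's weights.

    x:            (tokens, d_model)
    expert_idx:   (tokens,) int, chosen expert per token
    w_gate, w_up: (n_experts, d_model, d_ff)
    w_down:       (n_experts, d_ff, d_model)
    """
    g = np.einsum("td,tdf->tf", x, w_gate[expert_idx])  # gather + matmul
    u = np.einsum("td,tdf->tf", x, w_up[expert_idx])
    h = silu(g) * u                                     # gated activation
    return np.einsum("tf,tfd->td", h, w_down[expert_idx])
```

The optional gate+up fusion mentioned above amounts to concatenating `w_gate` and `w_up` so the first two matmuls become one.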
**Gated delta rule** (`backends/mlx/model_ops/gated_delta_rule.py`)
- Custom op with `mutates_args=("state",)` for recurrent state
carry-forward.
- Pattern handler emits `MetalKernelNode` (fused GPU kernel) or
`ScanNode` (fallback), selected via `use_custom_kernel` kwarg on the op.
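To make the `mutates_args=("state",)` contract concrete, here is a sequential numpy reference for one common formulation of the gated delta rule recurrence (single head; this is a semantics sketch, not the fused Metal kernel, and the exact gating/ordering in the actual op may differ):

```python
import numpy as np

def gated_delta_rule_ref(q, k, v, g, beta, state):
    """Sequential reference for the gated delta rule.

    q, k:    (T, d_k)    v: (T, d_v)    g, beta: (T,)
    state:   (d_k, d_v) recurrent matrix, mutated in place —
             mirroring mutates_args=("state",) on the custom op.
    """
    T = q.shape[0]
    out = np.empty((T, v.shape[1]))
    for t in range(T):
        state *= np.exp(g[t])                 # per-step decay gate
        pred = state.T @ k[t]                 # current prediction along k_t
        # delta rule: write back the correction toward v_t, scaled by beta_t
        state += beta[t] * np.outer(k[t], v[t] - pred)
        out[t] = state.T @ q[t]
    return out
```

Mutating `state` in place is what lets decode carry the recurrent state forward across calls without re-threading it through the graph.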
**New ops / schema**
- `mlx::gather_mm`, `mlx::gather_qmm`: fused gather + matmul for MoE
expert selection.
- `GatherMmNode`, `GatherQmmNode`, `ScanNode`, `MetalKernelNode`,
`ScatterAddNode` added to FlatBuffer schema + C++ runtime.
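Reference semantics for the MoE-routing ops, as a hedged numpy sketch (signatures are illustrative; the real fused kernels avoid materializing the gathered weight tensor, and `gather_qmm` additionally dequantizes on the fly):

```python
import numpy as np

def gather_mm_ref(x, weights, idx):
    """y[t] = x[t] @ weights[idx[t]]  — fused gather + matmul semantics.

    x:       (tokens, d_in)
    weights: (n_experts, d_in, d_out)
    idx:     (tokens,) expert index per token
    """
    return np.einsum("td,tdo->to", x, weights[idx])

def combine_topk_experts(expert_out, token_idx, probs, n_tokens):
    """Scatter-add combine for top-k routing (ScatterAddNode-style).

    expert_out: (rows, d) outputs of the gathered (token, expert) pairs
    token_idx:  (rows,) which token each row belongs to
    probs:      (rows,) router weight for that (token, expert) pair
    """
    y = np.zeros((n_tokens, expert_out.shape[1]))
    np.add.at(y, token_idx, probs[:, None] * expert_out)  # unbuffered scatter-add
    return y
```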
**Python runner** (`examples/models/qwen3_5_moe/run.py`)
- ExecuTorch pybinding runner with tokenizer support and vocab size
auto-detection from `.pte` metadata.
**CI** (`.github/workflows/mlx.yml`)
- `test-mlx-qwen35-moe`: tiny-model export + inference with a
  deterministic-output assertion and an AsType node-count check (≤ 23).
- `test_gated_delta_rule` tests added to `test-mlx` job.
## Usage
Export (downloads model automatically):

```
python export.py --model-id Qwen/Qwen3.5-35B-A3B --backend mlx --qlinear 4w --qlinear-group-size 64 --output-dir ./qwen35_moe_mlx
```

Run:

```
python -m executorch.examples.models.qwen3_5_moe.run --pte ./qwen35_moe_mlx/model.pte --tokenizer Qwen/Qwen3.5-35B-A3B --prompt "What is the capital of France?"
```

CI test (no download):

```
python export.py --tiny-test --backend mlx --qlinear 4w --output-dir /tmp/tiny
python -m executorch.examples.models.qwen3_5_moe.run --pte /tmp/tiny/model.pte --prompt-len 4 --max-new-tokens 5
```
## Further optimization ideas
* Write a chunked GDN kernel
* Turn off expert sorting in decode